import os
import pickle

import bnlearn as bn
import numpy as np


def save_datasets(SAVE_DATASET, label_save_dir, feature, true_data):
    """Pickle every entry of ``true_data`` into its own ``.pkl`` file.

    Args:
        SAVE_DATASET: Flag; when falsy (``False``, ``None``, ``0``, ...)
            nothing is written.
        label_save_dir: Prefix prepended verbatim to each file name. It is
            not path-joined, so pass a directory that ends with a path
            separator (e.g. ``"out/"``) to write inside that directory.
        feature: Unused; kept so existing call sites keep working.
        true_data: Mapping from label to array-like data. Each value is
            converted with ``np.array`` and pickled to
            ``f"{label_save_dir}{label}.pkl"``.
    """
    # Test truthiness instead of comparing with False (PEP 8); this also
    # makes None mean "do not save" rather than silently writing files.
    if not SAVE_DATASET:
        return

    for label in true_data:
        # f-string instead of `+` so non-string labels (e.g. ints) work too.
        file_name = f"{label_save_dir}{label}.pkl"
        with open(file_name, "wb") as fp:
            pickle.dump(np.array(true_data[label]), fp)
        print(file_name, " saved")


# Name of the example network passed to bnlearn; uncomment one alternative.
# bif_file = 'sprinkler'
# bif_file = 'alarm'
# bif_file = 'andes'
bif_file = 'asia'
# bif_file = 'pathfinder'
# bif_file = 'miserables'

# Number of samples drawn from the network.
n_samples = 20000

# Load the DAG together with its model parameters for the chosen network.
model = bn.import_DAG(bif_file)
print(model)
G = bn.plot(model)

df = bn.sampling(model, n=n_samples)
dataset = df.to_numpy()
label_names = df.columns

# Output directory. It may be given with or without a trailing separator,
# because it is combined with os.path.join below and created if it is missing.
label_save_dir = "data_path"
os.makedirs(label_save_dir, exist_ok=True)

# Write one file per column, "<label_save_dir>/intv0<label>.pkl", each holding
# an (n_samples, 1) array sliced from the full sample matrix. The "intv0" tag
# presumably marks non-interventional data — TODO confirm.
save_datasets(
    True,
    os.path.join(label_save_dir, "intv0"),
    None,
    {label: dataset[:, ind:ind + 1] for ind, label in enumerate(label_names)},
)


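# NOTE(review): leftover snippet that appears to save interventional (do(...))
# samples. It references names (file_root, prob_str, intv_var, key, do_data)
# that are not defined in this script, so it cannot be re-enabled as-is.
# file_name = file_root + f"sachs_P({prob_str}|do({intv_var}={key-1})).txt"
# with open(file_name , 'wb') as fp:
#     pickle.dump(np.array(do_data), fp)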